JuliaCon 2022
Patrick Altmeyer
CounterfactualExplanations.jl. From human to data-driven decision-making …
… where black boxes are a recipe for disaster.
“You cannot appeal to (algorithms). They do not listen. Nor do they bend.”
— Cathy O’Neil in Weapons of Math Destruction, 2016
Ground Truthing
Probabilistic Models
Counterfactual Reasoning
Ground Truthing
Let \(\mathcal{D}=\{(x,y)\}\) denote our true population of input-output pairs. Then we want to find a subsample of the true population
\[\mathcal{D}_n \subset \mathcal{D}\]
such that
\[\mathcal{D}_n \sim p(\mathcal{D})\]
There are lots of open questions and work to be done here, but that is a topic for another day.
Probabilistic Models
Let \(p(\mathcal{D}_n|\theta)\) denote the likelihood of observing our subsample \(\mathcal{D}_n\) under some model parameterized by \(\theta\). Then we typically want to maximize this likelihood with respect to the parameters (Murphy 2022):
\[\arg \max_{\theta} p(\mathcal{D}_n|\theta)\]
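As a toy illustration (my own sketch, not from the talk), the maximum-likelihood objective can be solved by gradient ascent on the log-likelihood. Here a one-parameter logistic regression stands in for the model; all names and data are assumed for the example.

```julia
# Toy sketch (assumption: plain Julia, logistic regression with a single
# parameter θ): maximize the likelihood p(D_n | θ) by gradient ascent.
σ(s) = 1 / (1 + exp(-s))

# Subsample D_n: inputs and binary labels generated from a known θ.
θ_true = 2.0
xs = collect(range(-3, 3; length = 200))
ys = [σ(θ_true * x) > 0.5 ? 1.0 : 0.0 for x in xs]

# Log-likelihood of the subsample under parameter θ:
loglik(θ) = sum(y * log(σ(θ * x)) + (1 - y) * log(1 - σ(θ * x)) for (x, y) in zip(xs, ys))

function fit_mle(xs, ys; lr = 0.01, steps = 1000)
    θ = 0.0
    for _ in 1:steps
        # Gradient of the log-likelihood with respect to θ:
        θ += lr * sum((y - σ(θ * x)) * x for (x, y) in zip(xs, ys))
    end
    return θ
end

θ̂ = fit_mle(xs, ys) # maximum-likelihood estimate
```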
[…] deep neural networks are typically very underspecified by the available data, and […] parameters [therefore] correspond to a diverse variety of compelling explanations for the data. (Wilson 2020)
In this setting it is often crucial to treat models probabilistically!
Probabilistic models are covered only briefly today. More in my other talk.
Counterfactual Reasoning
We can now make predictions - great! But do we know how the predictions are actually made?
Let \(\hat\theta\) denote our MLE estimate (or MAP in the probabilistic setting). Then we are interested in understanding how predictions of our model change with respect to input changes.
\[\nabla_x p(y|x,\mathcal{D}_n)\]
Even though […] interpretability is of great importance and should be pursued, explanations can, in principle, be offered without opening the “black box”. (Wachter, Mittelstadt, and Russell 2017)
The objective originally proposed by Wachter, Mittelstadt, and Russell (2017) is as follows:
\[ \min_{x^\prime \in \mathcal{X}} h(x^\prime) \ \ \ \mbox{s. t.} \ \ \ M(x^\prime) = t \qquad(1)\]
where \(h\) relates to the complexity of the counterfactual and \(M\) denotes the classifier.
Typically this is approximated through regularization:
\[ x^\prime = \arg \min_{x^\prime} \ell(M(x^\prime),t) + \lambda h(x^\prime) \qquad(2)\]
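A minimal sketch of Equation (2), assuming a toy logistic classifier and squared distance to the factual as the complexity penalty \(h\); this is illustrative only, not the package implementation.

```julia
# Gradient descent on Eq. (2): ℓ is cross-entropy, h(x′) = ||x′ - x||².
# All numbers below are assumed toy values.
σ(s) = 1 / (1 + exp(-s))

w, b = [2.0, -1.0], 0.0   # toy classifier M(x) = σ(w'x + b)
x = [-1.0, 1.0]           # factual, predicted class 0
t, λ, lr = 1.0, 0.1, 0.1  # target class, penalty strength, step size

x′ = copy(x)
for _ in 1:500
    p = σ(w' * x′ + b)
    ∇ℓ = (p - t) .* w          # gradient of the cross-entropy loss
    ∇h = 2λ .* (x′ .- x)       # gradient of the distance penalty
    x′ .-= lr .* (∇ℓ .+ ∇h)    # in-place gradient step
end
```

After the loop, `x′` has crossed the decision boundary while staying close to the factual `x`.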
Yes and no!
While both are methodologically very similar, adversarial examples are meant to go undetected while CEs ought to be meaningful.
Effective counterfactuals should meet certain criteria ✅
NO!
Causal inference: counterfactuals are thought of as unobserved states of the world that we would like to observe in order to establish causality.
Counterfactual explanations: counterfactuals are generated by perturbing the features of a sample after some model has been trained.
But still … there is an intriguing link between the two domains.
When people say that counterfactuals should look realistic or plausible, they really mean that counterfactuals should be generated by the same Data Generating Process (DGP) as the factuals:
\[ x^\prime \sim p(x) \]
But how do we estimate \(p(x)\)? Two probabilistic approaches …
Schut et al. (2021) note that by maximizing predictive probabilities \(\sigma(M(x^\prime))\) for probabilistic models \(M\in\mathcal{\widetilde{M}}\) one implicitly minimizes epistemic and aleatoric uncertainty.
\[ x^\prime = \arg \min_{x^\prime} \ell(M(x^\prime),t) \ \ \ , \ \ \ M\in\mathcal{\widetilde{M}} \qquad(3)\]
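A toy sketch of Equation (3), assuming a two-member ensemble of logistic classifiers stands in for the probabilistic model, with the loss taken as minus the log of the averaged target probability and no distance penalty; names and weights are made up for illustration.

```julia
# Minimize -log p̄(x′), where p̄ averages predictive probabilities over
# an (assumed) ensemble of logistic classifiers.
σ(s) = 1 / (1 + exp(-s))

W = [[2.0, -1.0], [1.5, -0.5]]   # assumed ensemble weights
x′ = [-1.0, 1.0]                 # factual, predicted class 0 by both members

for _ in 1:300
    ps = [σ(w' * x′) for w in W]
    p̄ = sum(ps) / length(W)                                     # averaged probability
    ∇p̄ = sum(p * (1 - p) .* w for (p, w) in zip(ps, W)) ./ length(W)
    x′ .-= 0.1 .* (-(1 / p̄) .* ∇p̄)                              # descend -log p̄
end
```

Because the averaged probability is only high where all ensemble members agree, maximizing it steers the counterfactual toward low-uncertainty regions.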
Instead of perturbing samples directly, some have proposed to instead traverse a lower-dimensional latent embedding learned through a generative model (Joshi et al. 2019).
\[ z^\prime = \arg \min_{z^\prime} \ell(M(dec(z^\prime)),t) + \lambda h(x^\prime) \qquad(4)\]
and
\[x^\prime = dec(z^\prime)\]
where \(dec(\cdot)\) is the decoder function.
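A toy sketch of Equation (4), assuming a fixed linear "decoder" stands in for the VAE and a toy logistic classifier; everything here is made up for illustration, not the REVISE implementation.

```julia
# Search in a 1-D latent space; the counterfactual is decoded back into
# feature space, so it stays on the decoder's manifold by construction.
σ(s) = 1 / (1 + exp(-s))

A, μ = [1.0, -0.5], [0.5, 0.5]   # assumed decoder parameters
dec(z) = z .* A .+ μ             # toy decoder dec(z) = z*A + μ
w, b = [2.0, -1.0], 0.0          # toy classifier M(x) = σ(w'x + b)
x, t, λ, lr = dec(-1.0), 1.0, 0.1, 0.1   # factual lies on the manifold

function latent_search(z; steps = 500)
    for _ in 1:steps
        x′ = dec(z)
        p = σ(w' * x′ + b)
        ∂ℓ = (p - t) * (w' * A)          # chain rule through the decoder
        ∂h = 2λ * ((x′ .- x)' * A)       # penalty h(x′) = ||x′ - x||²
        z -= lr * (∂ℓ + ∂h)
    end
    return z
end

z′ = latent_search(-1.0)
x′ = dec(z′)   # decoded counterfactual
```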
Work currently scattered across different GitHub repositories …
CounterfactualExplanations.jl 📦… until now!
Julia has an edge with respect to Trustworthy AI: it’s open-source, uniquely transparent and interoperable 🔴🟢🟣
Modular, composable, scalable!
Figure 6: Overview of package architecture. Modules are shown in red, structs in green and functions in blue.
using CounterfactualExplanations, Plots, GraphRecipes
plt = plot(AbstractGenerator, method=:tree, fontsize=10, nodeshape=:rect, size=(1000,700))
savefig(plt, joinpath(www_path,"generators.png"))
Figure 7: Type tree for AbstractGenerator.
plt = plot(AbstractFittedModel, method=:tree, fontsize=10, nodeshape=:rect, size=(1000,700))
savefig(plt, joinpath(www_path,"models.png"))
Figure 8: Type tree for AbstractFittedModel.
We begin by instantiating the fitted model …
… then based on its prediction for \(x\) we choose the opposite label as our target …
… et voilà!
GenericGenerator. The contour (left) shows the predicted probabilities of the classifier (Logistic Regression).
This time we use a Bayesian classifier …
… and once again choose our target label as before …
In this case the Bayesian approach yields a similar outcome.
GreedyGenerator. The contour (left) shows the predicted probabilities of the classifier (Bayesian Logistic Regression).
Using the same classifier as before we can either use the dedicated REVISEGenerator …
# Counterfactual search:
generator = REVISEGenerator()
counterfactual = generate_counterfactual(
x, target, counterfactual_data, M, generator
)
… or realize that REVISE (Joshi et al. 2019) just boils down to generic search in a latent space:
We have essentially combined latent search with a probabilistic classifier (as in Antorán et al. (2020)).
REVISEGenerator.
… instantiating the model and attaching the VAE.
The results in Figure 13 look great!
But things can also go wrong …
The VAE used to generate the counterfactual in Figure 14 is not expressive enough.
The counterfactual in Figure 15 is also valid … what to do?
Step 1: add composite type as subtype of AbstractFittedModel.
Step 2: dispatch logits and probs methods for new model type.
using Statistics
using Flux
import CounterfactualExplanations.Models: logits, probs
logits(M::FittedEnsemble, X::AbstractArray) = mean(Flux.stack([nn(X) for nn in M.ensemble], 3), dims=3)
probs(M::FittedEnsemble, X::AbstractArray) = mean(Flux.stack([softmax(nn(X)) for nn in M.ensemble], 3), dims=3)
M = FittedEnsemble(ensemble)
Results for a simple deep ensemble also look convincing!
Adding support for torch models was easy! Here’s how I implemented it for torch classifiers trained in R.
Step 1: add composite type as subtype of AbstractFittedModel
Done here.
Step 2: dispatch logits and probs methods for new model type.
Done here.
Step 3: add gradient access.
Done here.
M = RTorchModel(model)
# Select target class:
y = round(probs(M, x)[1])
target = ifelse(y==1.0,0.0,1.0) # opposite label as target
# Define generator:
generator = GenericGenerator()
# Generate recourse:
counterfactual = generate_counterfactual(
x, target, counterfactual_data, M, generator
)
GenericGenerator and RTorchModel.
Idea 💡: let’s implement a generic generator with dropout!
Step 1: create a subtype of AbstractGradientBasedGenerator (adhering to some basic rules).
# Composite type:
abstract type AbstractDropoutGenerator <: AbstractGradientBasedGenerator end
struct DropoutGenerator <: AbstractDropoutGenerator
    loss::Symbol # loss function
    complexity::Function # complexity function
    mutability::Union{Nothing,Vector{Symbol}} # mutability constraints
    λ::AbstractFloat # strength of penalty
    ϵ::AbstractFloat # step size
    τ::AbstractFloat # tolerance for convergence
    p_dropout::AbstractFloat # dropout rate
end
Step 2: implement logic for generating perturbations.
import CounterfactualExplanations.Generators: generate_perturbations, ∇
using StatsBase
function generate_perturbations(generator::AbstractDropoutGenerator, counterfactual_state::State)
    𝐠ₜ = ∇(generator, counterfactual_state.M, counterfactual_state) # gradient
    # Dropout:
    set_to_zero = sample(1:length(𝐠ₜ), Int(round(generator.p_dropout * length(𝐠ₜ))), replace=false)
    𝐠ₜ[set_to_zero] .= 0
    Δx′ = -(generator.ϵ .* 𝐠ₜ) # gradient step
    return Δx′
end
# Instantiate:
using LinearAlgebra
generator = DropoutGenerator(
:logitbinarycrossentropy,
norm,
nothing,
0.1,
0.1,
1e-5,
0.5
)
counterfactual = generate_counterfactual(
x, target, counterfactual_data, M, generator
)
DropoutGenerator and RTorchModel.
Develop package, register and submit to JuliaCon 2022.
Native support for deep learning models (Flux, torch).
Add latent space search.
MLJ, GLM, …
Flux optimizers.
Explaining Black-Box Models through Counterfactuals – JuliaCon 2022 – Patrick Altmeyer